package com.example.security;

import com.example.config.RateLimitService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.core.Cursor;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.ScanOptions;
import org.springframework.stereotype.Component;

import jakarta.servlet.http.HttpServletRequest;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.TimeUnit;

/**
 * API滥用防护组件
 */
@Slf4j
@Component
@RequiredArgsConstructor
public class ApiAbuseProtection {

    private final RateLimitService rateLimitService;
    private final RedisTemplate<String, Object> redisTemplate;

    /**
     * 检查API滥用
     */
    public boolean checkApiAbuse(String userId, String apiPath, HttpServletRequest request) {
        try {
            String clientIp = getClientIp(request);
            
            // 1. 检查用户API调用频率
            if (!checkUserApiRate(userId)) {
                recordAbuseEvent(userId, clientIp, "USER_API_RATE_EXCEEDED", apiPath);
                return false;
            }
            
            // 2. 检查IP API调用频率
            if (!checkIpApiRate(clientIp)) {
                recordAbuseEvent(userId, clientIp, "IP_API_RATE_EXCEEDED", apiPath);
                return false;
            }
            
            // 3. 检查特定API端点频率
            if (!checkEndpointRate(userId, apiPath)) {
                recordAbuseEvent(userId, clientIp, "ENDPOINT_RATE_EXCEEDED", apiPath);
                return false;
            }
            
            // 4. 检查异常行为模式
            if (detectAbnormalPattern(userId, apiPath, request)) {
                recordAbuseEvent(userId, clientIp, "ABNORMAL_PATTERN", apiPath);
                return false;
            }
            
            // 5. 检查是否在黑名单中
            if (isInBlacklist(userId, clientIp)) {
                recordAbuseEvent(userId, clientIp, "BLACKLISTED", apiPath);
                return false;
            }
            
            // 记录正常API调用
            recordApiCall(userId, clientIp, apiPath);
            
            return true;
            
        } catch (Exception e) {
            log.error("检查API滥用异常: {}", e.getMessage(), e);
            return true; // 异常情况下允许通过
        }
    }

    /**
     * 检查用户API调用频率
     */
    private boolean checkUserApiRate(String userId) {
        // 每小时1000次
        if (!rateLimitService.isAllowed("api_user:" + userId, 1000, 3600)) {
            return false;
        }
        
        // 每分钟50次
        if (!rateLimitService.isAllowed("api_user_minute:" + userId, 50, 60)) {
            return false;
        }
        
        return true;
    }

    /**
     * 检查IP API调用频率
     */
    private boolean checkIpApiRate(String clientIp) {
        // 每小时2000次
        if (!rateLimitService.isAllowed("api_ip:" + clientIp, 2000, 3600)) {
            return false;
        }
        
        // 每分钟100次
        if (!rateLimitService.isAllowed("api_ip_minute:" + clientIp, 100, 60)) {
            return false;
        }
        
        return true;
    }

    /**
     * 检查特定端点频率
     */
    private boolean checkEndpointRate(String userId, String apiPath) {
        // 敏感端点限制更严格
        if (isSensitiveEndpoint(apiPath)) {
            return rateLimitService.isAllowed("api_sensitive:" + userId + ":" + apiPath, 10, 3600);
        }
        
        // 普通端点
        return rateLimitService.isAllowed("api_endpoint:" + userId + ":" + apiPath, 100, 3600);
    }

    /**
     * 检测异常行为模式
     */
    private boolean detectAbnormalPattern(String userId, String apiPath, HttpServletRequest request) {
        try {
            // 1. 检查请求间隔是否过短
            if (isRequestTooFrequent(userId)) {
                return true;
            }
            
            // 2. 检查是否为自动化工具
            if (isAutomatedTool(request)) {
                return true;
            }
            
            // 3. 检查API调用模式是否异常
            if (hasAbnormalCallPattern(userId)) {
                return true;
            }
            
            // 4. 检查是否存在爬虫行为
            if (isCrawlerBehavior(userId, apiPath)) {
                return true;
            }
            
            return false;
            
        } catch (Exception e) {
            log.error("检测异常行为模式异常: {}", e.getMessage());
            return false;
        }
    }

    /**
     * 检查请求是否过于频繁
     */
    private boolean isRequestTooFrequent(String userId) {
        try {
            String key = "api_last_request:" + userId;
            Long lastRequestTime = (Long) redisTemplate.opsForValue().get(key);
            long currentTime = System.currentTimeMillis();
            
            if (lastRequestTime != null) {
                long interval = currentTime - lastRequestTime;
                if (interval < 100) { // 100毫秒内的请求认为过于频繁
                    return true;
                }
            }
            
            redisTemplate.opsForValue().set(key, currentTime, 60, TimeUnit.SECONDS);
            return false;
            
        } catch (Exception e) {
            return false;
        }
    }

    /**
     * 检查是否为自动化工具
     */
    private boolean isAutomatedTool(HttpServletRequest request) {
        String userAgent = request.getHeader("User-Agent");
        if (userAgent == null) {
            return true;
        }
        
        String[] botUserAgents = {
            "bot", "crawler", "spider", "scraper", "curl", "wget", "python", 
            "java", "go-http", "okhttp", "apache-httpclient", "postman"
        };
        
        String lowerUserAgent = userAgent.toLowerCase();
        for (String botAgent : botUserAgents) {
            if (lowerUserAgent.contains(botAgent)) {
                return true;
            }
        }
        
        return false;
    }

    /**
     * 检查API调用模式是否异常
     */
    private boolean hasAbnormalCallPattern(String userId) {
        try {
            String key = "api_pattern:" + userId;
            Map<Object, Object> callPattern = redisTemplate.opsForHash().entries(key);
            
            if (callPattern.isEmpty()) {
                return false;
            }
            
            // 检查是否只调用特定几个API
            if (callPattern.size() < 3) {
                long totalCalls = callPattern.values().stream()
                        .mapToLong(v -> Long.parseLong(v.toString()))
                        .sum();
                
                if (totalCalls > 100) { // 大量调用但API种类很少
                    return true;
                }
            }
            
            return false;
            
        } catch (Exception e) {
            return false;
        }
    }

    /**
     * 检查是否为爬虫行为
     */
    private boolean isCrawlerBehavior(String userId, String apiPath) {
        try {
            // 检查是否频繁访问列表类API
            if (apiPath.contains("/list") || apiPath.contains("/search")) {
                String key = "crawler_check:" + userId;
                Long count = redisTemplate.opsForValue().increment(key);
                redisTemplate.expire(key, 3600, TimeUnit.SECONDS);
                
                return count > 50; // 1小时内超过50次列表查询
            }
            
            return false;
            
        } catch (Exception e) {
            return false;
        }
    }

    /**
     * 检查是否在黑名单中
     */
    private boolean isInBlacklist(String userId, String clientIp) {
        try {
            return redisTemplate.hasKey("blacklist_user:" + userId) ||
                   redisTemplate.hasKey("blacklist_ip:" + clientIp);
        } catch (Exception e) {
            return false;
        }
    }

    /**
     * 判断是否为敏感端点
     */
    private boolean isSensitiveEndpoint(String apiPath) {
        String[] sensitiveEndpoints = {
            "/api/payments/", "/api/admin/", "/api/user/profile",
            "/api/auth/", "/api/upload/", "/api/security/"
        };
        
        for (String endpoint : sensitiveEndpoints) {
            if (apiPath.startsWith(endpoint)) {
                return true;
            }
        }
        
        return false;
    }

    /**
     * 记录API调用
     */
    private void recordApiCall(String userId, String clientIp, String apiPath) {
        try {
            // 记录用户API调用模式
            String patternKey = "api_pattern:" + userId;
            redisTemplate.opsForHash().increment(patternKey, apiPath, 1);
            redisTemplate.expire(patternKey, 3600, TimeUnit.SECONDS);
            
            // 记录API调用统计
            String statsKey = "api_stats:" + LocalDateTime.now().getHour();
            redisTemplate.opsForHash().increment(statsKey, "total", 1);
            redisTemplate.opsForHash().increment(statsKey, "user:" + userId, 1);
            redisTemplate.opsForHash().increment(statsKey, "ip:" + clientIp, 1);
            redisTemplate.opsForHash().increment(statsKey, "endpoint:" + apiPath, 1);
            redisTemplate.expire(statsKey, 3600, TimeUnit.SECONDS);
            
        } catch (Exception e) {
            log.debug("记录API调用统计异常: {}", e.getMessage());
        }
    }

    /**
     * 记录滥用事件
     */
    private void recordAbuseEvent(String userId, String clientIp, String abuseType, String apiPath) {
        try {
            Map<String, Object> event = new HashMap<>();
            event.put("userId", userId);
            event.put("clientIp", clientIp);
            event.put("abuseType", abuseType);
            event.put("apiPath", apiPath);
            event.put("timestamp", LocalDateTime.now().toString());
            
            String eventKey = "api_abuse_event:" + System.currentTimeMillis();
            redisTemplate.opsForValue().set(eventKey, event, 7, TimeUnit.DAYS);
            
            // 增加滥用计数
            String abuseCountKey = "api_abuse_count:" + userId;
            Long abuseCount = redisTemplate.opsForValue().increment(abuseCountKey);
            redisTemplate.expire(abuseCountKey, 24, TimeUnit.HOURS);
            
            // 如果滥用次数过多，加入黑名单
            if (abuseCount >= 10) {
                addToBlacklist(userId, clientIp, "API滥用次数过多");
            }
            
            log.warn("API滥用事件: userId={}, ip={}, type={}, path={}", 
                    userId, clientIp, abuseType, apiPath);
            
        } catch (Exception e) {
            log.error("记录滥用事件异常: {}", e.getMessage(), e);
        }
    }

    /**
     * 添加到黑名单
     */
    public void addToBlacklist(String userId, String clientIp, String reason) {
        try {
            Map<String, Object> blacklistInfo = new HashMap<>();
            blacklistInfo.put("reason", reason);
            blacklistInfo.put("timestamp", LocalDateTime.now().toString());
            
            if (userId != null) {
                redisTemplate.opsForValue().set("blacklist_user:" + userId, blacklistInfo, 
                        24, TimeUnit.HOURS);
            }
            
            if (clientIp != null) {
                redisTemplate.opsForValue().set("blacklist_ip:" + clientIp, blacklistInfo, 
                        24, TimeUnit.HOURS);
            }
            
            log.warn("添加到黑名单: userId={}, ip={}, reason={}", userId, clientIp, reason);
            
        } catch (Exception e) {
            log.error("添加黑名单异常: {}", e.getMessage(), e);
        }
    }

    /**
     * 从黑名单移除
     */
    public void removeFromBlacklist(String userId, String clientIp) {
        try {
            if (userId != null) {
                redisTemplate.delete("blacklist_user:" + userId);
            }
            
            if (clientIp != null) {
                redisTemplate.delete("blacklist_ip:" + clientIp);
            }
            
            log.info("从黑名单移除: userId={}, ip={}", userId, clientIp);
            
        } catch (Exception e) {
            log.error("移除黑名单异常: {}", e.getMessage(), e);
        }
    }

    /**
     * 获取客户端IP
     */
    private String getClientIp(HttpServletRequest request) {
        String ip = request.getHeader("X-Forwarded-For");
        if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getRemoteAddr();
        }
        
        if (ip != null && ip.contains(",")) {
            ip = ip.split(",")[0].trim();
        }
        
        return ip != null ? ip : "unknown";
    }

    /**
     * 获取API滥用统计
     */
    public Map<String, Object> getAbuseStatistics() {
        try {
            Map<String, Object> stats = new HashMap<>();
            
            // 获取当前小时的统计
            String statsKey = "api_stats:" + LocalDateTime.now().getHour();
            Map<Object, Object> hourlyStats = redisTemplate.opsForHash().entries(statsKey);
            stats.put("hourlyStats", hourlyStats);
            
            // 获取黑名单数量
            long blacklistedUsers = redisTemplate.keys("blacklist_user:*").size();
            long blacklistedIps = redisTemplate.keys("blacklist_ip:*").size();
            stats.put("blacklistedUsers", blacklistedUsers);
            stats.put("blacklistedIps", blacklistedIps);
            
            return stats;
            
        } catch (Exception e) {
            log.error("获取滥用统计异常: {}", e.getMessage(), e);
            return new HashMap<>();
        }
    }
}
